/**
 * Copyright (c) 2021-2022 Huawei Device Co., Ltd.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "unit_ecma_test.h"
#include "optimizer/ir/graph_cloner.h"
#include "optimizer/optimizations/cleanup.h"
#include "optimizer/optimizations/phi_type_resolving.h"

namespace ark::compiler {
/**
 * GoogleTest fixture for the PhiTypeResolving pass tests below.
 *
 * It adds no state or behavior on top of AsmTest; the implicitly declared
 * default constructor is used (Rule of Zero).
 */
class PhiTypeResolvingTest : public AsmTest {};

// NOLINTBEGIN(readability-magic-numbers)
// Positive case: a single `any` Phi whose inputs are both
// CastValueToAnyType(ECMASCRIPT_INT_TYPE) of i64 constants.
// After PhiTypeResolving + Cleanup, the expected graph has:
//  - the Phi retyped to i32 and fed directly by the constants (0, 1),
//  - one CastValueToAnyType(ECMASCRIPT_INT_TYPE) placed after the Phi to
//    feed its `any` user (the Return),
//  - no remaining copies of the original input casts (3, 4).
TEST_F(PhiTypeResolvingTest, ResolvePhi)
{
    auto graph = CreateGraphDynWithDefaultRuntime();
    GRAPH(graph)
    {
        CONSTANT(0, 0).i64();
        CONSTANT(1, 1).i64();
        PARAMETER(2, 0).any();

        // Block 2 branches on a boolean unboxed from the parameter; block 3 is an
        // empty path, so block 4 has two predecessors and needs a Phi.
        BASIC_BLOCK(2, 3, 4)
        {
            INST(3, Opcode::CastValueToAnyType).any().AnyType(AnyBaseType::ECMASCRIPT_INT_TYPE).Inputs(0);
            INST(4, Opcode::CastValueToAnyType).any().AnyType(AnyBaseType::ECMASCRIPT_INT_TYPE).Inputs(1);
            INST(6, Opcode::CastAnyTypeValue).b().AnyType(AnyBaseType::ECMASCRIPT_BOOLEAN_TYPE).Inputs(2);
            INST(7, Opcode::IfImm).SrcType(DataType::BOOL).CC(CC_NE).Imm(0).Inputs(6);
        }
        BASIC_BLOCK(3, 4) {}

        BASIC_BLOCK(4, -1)
        {
            // Both Phi inputs are boxed ints, so the pass should be able to resolve its type.
            INST(8, Opcode::Phi).any().Inputs(3, 4);
            INST(9, Opcode::Return).any().Inputs(8);
        }
    }
    ASSERT_TRUE(graph->RunPass<PhiTypeResolving>());
    ASSERT_TRUE(graph->RunPass<Cleanup>());
    GraphChecker(graph).Check();

    // Expected result after both passes.
    auto graphOpt = CreateGraphDynWithDefaultRuntime();
    GRAPH(graphOpt)
    {
        CONSTANT(0, 0).i64();
        CONSTANT(1, 1).i64();
        PARAMETER(2, 0).any();

        BASIC_BLOCK(2, 3, 4)
        {
            INST(6, Opcode::CastAnyTypeValue).b().AnyType(AnyBaseType::ECMASCRIPT_BOOLEAN_TYPE).Inputs(2);
            INST(7, Opcode::IfImm).SrcType(DataType::BOOL).CC(CC_NE).Imm(0).Inputs(6);
        }
        BASIC_BLOCK(3, 4) {}

        BASIC_BLOCK(4, -1)
        {
            // Typed Phi over the raw constants; boxing happens once, after the Phi.
            INST(8, Opcode::Phi).i32().Inputs(0, 1);
            INST(10, Opcode::CastValueToAnyType).any().AnyType(AnyBaseType::ECMASCRIPT_INT_TYPE).Inputs(8);
            INST(9, Opcode::Return).any().Inputs(10);
        }
    }

    ASSERT_TRUE(GraphComparator().Compare(graph, graphOpt));
}

// Positive case with two chained Phis: Phi 13 merges two boxed-int casts, and
// Phi 8 merges Phi 13 with a third boxed-int cast. After PhiTypeResolving +
// Cleanup, the expected graph has both Phis retyped to i32 and fed by the raw
// constants (and by each other). A single CastValueToAnyType(ECMASCRIPT_INT_TYPE)
// follows the last Phi, feeding the `any` Return.
TEST_F(PhiTypeResolvingTest, Resolve2Phi)
{
    auto graph = CreateGraphDynWithDefaultRuntime();
    GRAPH(graph)
    {
        CONSTANT(0, 0).i64();
        CONSTANT(1, 1).i64();
        CONSTANT(21, 1).i64();
        PARAMETER(2, 0).any();
        PARAMETER(22, 1).any();

        // First diamond: branch on a boolean unboxed from parameter 2.
        BASIC_BLOCK(2, 3, 4)
        {
            INST(3, Opcode::CastValueToAnyType).any().AnyType(AnyBaseType::ECMASCRIPT_INT_TYPE).Inputs(0);
            INST(4, Opcode::CastValueToAnyType).any().AnyType(AnyBaseType::ECMASCRIPT_INT_TYPE).Inputs(1);
            INST(10, Opcode::CastValueToAnyType).any().AnyType(AnyBaseType::ECMASCRIPT_INT_TYPE).Inputs(21);
            INST(6, Opcode::CastAnyTypeValue).b().AnyType(AnyBaseType::ECMASCRIPT_BOOLEAN_TYPE).Inputs(2);
            INST(7, Opcode::IfImm).SrcType(DataType::BOOL).CC(CC_NE).Imm(0).Inputs(6);
        }
        BASIC_BLOCK(3, 4) {}

        // Join of the first diamond, which also starts the second diamond
        // (branch on parameter 22).
        BASIC_BLOCK(4, 5, 6)
        {
            INST(13, Opcode::Phi).any().Inputs(3, 4);
            INST(11, Opcode::CastAnyTypeValue).b().AnyType(AnyBaseType::ECMASCRIPT_BOOLEAN_TYPE).Inputs(22);
            INST(12, Opcode::IfImm).SrcType(DataType::BOOL).CC(CC_NE).Imm(0).Inputs(11);
        }
        BASIC_BLOCK(6, 5) {}

        BASIC_BLOCK(5, -1)
        {
            // The second Phi takes the first Phi as an input, so resolution must
            // propagate through the Phi chain.
            INST(8, Opcode::Phi).any().Inputs(13, 10);
            INST(9, Opcode::Return).any().Inputs(8);
        }
    }
    ASSERT_TRUE(graph->RunPass<PhiTypeResolving>());
    ASSERT_TRUE(graph->RunPass<Cleanup>());
    GraphChecker(graph).Check();

    // Expected result after both passes.
    auto graphOpt = CreateGraphDynWithDefaultRuntime();
    GRAPH(graphOpt)
    {
        CONSTANT(0, 0).i64();
        CONSTANT(1, 1).i64();
        CONSTANT(21, 1).i64();
        PARAMETER(2, 0).any();
        PARAMETER(22, 1).any();

        BASIC_BLOCK(2, 3, 4)
        {
            INST(6, Opcode::CastAnyTypeValue).b().AnyType(AnyBaseType::ECMASCRIPT_BOOLEAN_TYPE).Inputs(2);
            INST(7, Opcode::IfImm).SrcType(DataType::BOOL).CC(CC_NE).Imm(0).Inputs(6);
        }
        BASIC_BLOCK(3, 4) {}

        BASIC_BLOCK(4, 5, 6)
        {
            INST(13, Opcode::Phi).i32().Inputs(0, 1);
            INST(11, Opcode::CastAnyTypeValue).b().AnyType(AnyBaseType::ECMASCRIPT_BOOLEAN_TYPE).Inputs(22);
            INST(12, Opcode::IfImm).SrcType(DataType::BOOL).CC(CC_NE).Imm(0).Inputs(11);
        }
        BASIC_BLOCK(6, 5) {}

        BASIC_BLOCK(5, -1)
        {
            // Both Phis are typed; only the final value is re-boxed for Return.
            INST(8, Opcode::Phi).i32().Inputs(13, 21);
            INST(25, Opcode::CastValueToAnyType).any().AnyType(AnyBaseType::ECMASCRIPT_INT_TYPE).Inputs(8);
            INST(9, Opcode::Return).any().Inputs(25);
        }
    }

    ASSERT_TRUE(GraphComparator().Compare(graph, graphOpt));
}

// Negative case: the Phi's inputs are boxed with different any-types
// (ECMASCRIPT_INT_TYPE and ECMASCRIPT_DOUBLE_TYPE). The test expects
// PhiTypeResolving and Cleanup both to report no change, and the graph to
// stay identical to a clone taken before the passes ran.
TEST_F(PhiTypeResolvingTest, ResolvePhiNotApply)
{
    auto graph = CreateGraphDynWithDefaultRuntime();
    GRAPH(graph)
    {
        CONSTANT(0, 0).i64();
        CONSTANT(1, 1.1).f64();
        PARAMETER(2, 0).any();

        BASIC_BLOCK(2, 3, 4)
        {
            // Mismatched any-types on the two Phi inputs.
            INST(3, Opcode::CastValueToAnyType).any().AnyType(AnyBaseType::ECMASCRIPT_INT_TYPE).Inputs(0);
            INST(4, Opcode::CastValueToAnyType).any().AnyType(AnyBaseType::ECMASCRIPT_DOUBLE_TYPE).Inputs(1);
            INST(6, Opcode::CastAnyTypeValue).b().AnyType(AnyBaseType::ECMASCRIPT_BOOLEAN_TYPE).Inputs(2);
            INST(7, Opcode::IfImm).SrcType(DataType::BOOL).CC(CC_NE).Imm(0).Inputs(6);
        }
        BASIC_BLOCK(3, 4) {}

        BASIC_BLOCK(4, -1)
        {
            INST(8, Opcode::Phi).any().Inputs(3, 4);
            INST(9, Opcode::Return).any().Inputs(8);
        }
    }
    // Snapshot the input graph; it doubles as the expected result.
    auto graphOpt = GraphCloner(graph, graph->GetAllocator(), graph->GetLocalAllocator()).CloneGraph();

    ASSERT_FALSE(graph->RunPass<PhiTypeResolving>());
    ASSERT_FALSE(graph->RunPass<Cleanup>());
    GraphChecker(graph).Check();

    ASSERT_TRUE(GraphComparator().Compare(graph, graphOpt));
}

// Positive case where the Phi's user is CompareAnyType (which produces a bool)
// rather than a Return of an `any` value. The Phi merges two boxed doubles
// (ECMASCRIPT_DOUBLE_TYPE over f64 constants), while the user checks for
// ECMASCRIPT_INT_TYPE. The expected graph types the Phi as f64 and re-boxes it
// with CastValueToAnyType(ECMASCRIPT_DOUBLE_TYPE) before CompareAnyType.
// CompareAnyType itself keeps its original ECMASCRIPT_INT_TYPE check.
TEST_F(PhiTypeResolvingTest, ResolvePhiUserNotAnyType)
{
    auto graph = CreateGraphDynWithDefaultRuntime();
    GRAPH(graph)
    {
        CONSTANT(0, 0).f64();
        CONSTANT(1, 1).f64();
        PARAMETER(2, 0).any();

        BASIC_BLOCK(2, 3, 4)
        {
            INST(3, Opcode::CastValueToAnyType).any().AnyType(AnyBaseType::ECMASCRIPT_DOUBLE_TYPE).Inputs(0);
            INST(4, Opcode::CastValueToAnyType).any().AnyType(AnyBaseType::ECMASCRIPT_DOUBLE_TYPE).Inputs(1);
            INST(6, Opcode::CastAnyTypeValue).b().AnyType(AnyBaseType::ECMASCRIPT_BOOLEAN_TYPE).Inputs(2);
            INST(7, Opcode::IfImm).SrcType(DataType::BOOL).CC(CC_NE).Imm(0).Inputs(6);
        }
        BASIC_BLOCK(3, 4) {}

        BASIC_BLOCK(4, -1)
        {
            INST(8, Opcode::Phi).any().Inputs(3, 4);
            INST(9, Opcode::CompareAnyType).AnyType(AnyBaseType::ECMASCRIPT_INT_TYPE).b().Inputs(8);
            INST(10, Opcode::Return).b().Inputs(9);
        }
    }
    ASSERT_TRUE(graph->RunPass<PhiTypeResolving>());
    ASSERT_TRUE(graph->RunPass<Cleanup>());
    GraphChecker(graph).Check();

    // Expected result after both passes.
    auto graphOpt = CreateGraphDynWithDefaultRuntime();
    GRAPH(graphOpt)
    {
        CONSTANT(0, 0).f64();
        CONSTANT(1, 1).f64();
        PARAMETER(2, 0).any();

        BASIC_BLOCK(2, 3, 4)
        {
            INST(6, Opcode::CastAnyTypeValue).b().AnyType(AnyBaseType::ECMASCRIPT_BOOLEAN_TYPE).Inputs(2);
            INST(7, Opcode::IfImm).SrcType(DataType::BOOL).CC(CC_NE).Imm(0).Inputs(6);
        }
        BASIC_BLOCK(3, 4) {}

        BASIC_BLOCK(4, -1)
        {
            // CompareAnyType still consumes an `any` value, so the typed Phi is re-boxed first.
            INST(8, Opcode::Phi).f64().Inputs(0, 1);
            INST(11, Opcode::CastValueToAnyType).any().AnyType(AnyBaseType::ECMASCRIPT_DOUBLE_TYPE).Inputs(8);
            INST(9, Opcode::CompareAnyType).AnyType(AnyBaseType::ECMASCRIPT_INT_TYPE).b().Inputs(11);
            INST(10, Opcode::Return).b().Inputs(9);
        }
    }

    ASSERT_TRUE(GraphComparator().Compare(graph, graphOpt));
}
// NOLINTEND(readability-magic-numbers)

}  // namespace ark::compiler
